{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyOGB/4NLIJ2sfWa+yQFGuhp",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sugarforever/LangChain-Tutorials/blob/main/Langchain_HuggingFacePipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GRlP9ClDf1ZO",
        "outputId": "01ea1738-45ed-46b6-f9d0-21112d0eab6b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting transformers\n",
            "  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m40.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting langchain\n",
            "  Downloading langchain-0.0.244-py3-none-any.whl (1.4 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m35.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n",
            "Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)\n",
            "  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m23.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n",
            "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)\n",
            "  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting safetensors>=0.3.1 (from transformers)\n",
            "  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m52.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n",
            "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.0.19)\n",
            "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (3.8.5)\n",
            "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.2)\n",
            "Collecting dataclasses-json<0.6.0,>=0.5.7 (from langchain)\n",
            "  Downloading dataclasses_json-0.5.13-py3-none-any.whl (26 kB)\n",
            "Collecting langsmith<0.1.0,>=0.0.11 (from langchain)\n",
            "  Downloading langsmith-0.0.14-py3-none-any.whl (29 kB)\n",
            "Requirement already satisfied: numexpr<3.0.0,>=2.8.4 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.8.4)\n",
            "Collecting openapi-schema-pydantic<2.0,>=1.2 (from langchain)\n",
            "  Downloading openapi_schema_pydantic-1.2.4-py3-none-any.whl (90 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.0/90.0 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: pydantic<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.10.11)\n",
            "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (8.2.2)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.1.0)\n",
            "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (2.0.12)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.4)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.2)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.0)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
            "Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.6.0,>=0.5.7->langchain)\n",
            "  Downloading marshmallow-3.20.1-py3-none-any.whl (49 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.4/49.4 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.6.0,>=0.5.7->langchain)\n",
            "  Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.16)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
            "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain) (2.0.2)\n",
            "Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses-json<0.6.0,>=0.5.7->langchain)\n",
            "  Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
            "Installing collected packages: tokenizers, safetensors, mypy-extensions, marshmallow, typing-inspect, openapi-schema-pydantic, langsmith, huggingface-hub, transformers, dataclasses-json, langchain\n",
            "Successfully installed dataclasses-json-0.5.13 huggingface-hub-0.16.4 langchain-0.0.244 langsmith-0.0.14 marshmallow-3.20.1 mypy-extensions-1.0.0 openapi-schema-pydantic-1.2.4 safetensors-0.3.1 tokenizers-0.13.3 transformers-4.31.0 typing-inspect-0.9.0\n"
          ]
        }
      ],
      "source": [
        "# Pinned to the versions recorded in this notebook's last install log, since the\n",
        "# cells below rely on the `from langchain import HuggingFacePipeline` import path.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q transformers==4.31.0 langchain==0.0.244"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from langchain import HuggingFacePipeline\n",
        "\n",
        "# Wrap a locally loaded Hugging Face pipeline as a LangChain LLM.\n",
        "model_id = \"facebook/bart-large-cnn\"\n",
        "task = \"summarization\"\n",
        "\n",
        "# from_model_id defaults to device=-1 (CPU) and only logs a warning when a GPU\n",
        "# is present, so select CUDA device 0 explicitly whenever one is available and\n",
        "# fall back to the CPU otherwise.\n",
        "device = 0 if torch.cuda.is_available() else -1\n",
        "\n",
        "llm = HuggingFacePipeline.from_model_id(model_id=model_id, task=task, device=device)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bjEafRssgBrZ",
        "outputId": "359faf05-3796-44a1-8625-191aab4f918b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:langchain.llms.huggingface_pipeline:Device has 1 GPUs available. Provide device={deviceId} to `from_model_id` to use availableGPUs for execution. deviceId is -1 (default) for CPU and can be a positive integer associated with CUDA device id.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain import LLMChain, PromptTemplate\n",
        "\n",
        "# Text handed to the summarization model; it is passed through unchanged.\n",
        "question = \"\"\"\n",
        "The U.S. will be looking to snag their third straight World Cup title — and its fifth overall.\n",
        "\n",
        "The U.S. women's national team (USWNT) has held the No. 1 spot in FIFA's rankings for years, and is the odds-on favorite to win once again. But this year's tournament is considered fairly wide open, with several teams having a decent shot at the title.\n",
        "\n",
        "Indeed, the U.S. has had an interesting pattern of late: winning the World Cup, but losing at the Olympics. The USWNT won the World Cup in 2015, then failed to medal at the 2016 Olympics. They won the 2019 World Cup, then took home the bronze at the Tokyo Olympics in 2021.\n",
        "\"\"\"\n",
        "\n",
        "# Pass-through prompt: the template consists of the `input` placeholder only.\n",
        "template = \"{input}\"\n",
        "prompt = PromptTemplate(input_variables=[\"input\"], template=template)\n",
        "\n",
        "# Chain the prompt to the pipeline-backed LLM and run it on the text above.\n",
        "llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
        "answer = llm_chain.run(question)"
      ],
      "metadata": {
        "id": "cAONy66ugTF2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "answer"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "p8Ea9uirJlAV",
        "outputId": "5de888e9-e615-465c-86f9-658345f48578"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\"The U.S. women's national team is the odds-on favorite to win this year's World Cup. The USWNT won the World Cup in 2015, then failed to medal at the 2016 Olympics. They won the 2019 World Cup, then took home the bronze at the Tokyo Olympics in 2021.\""
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "len(answer)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GGH563-IJnB4",
        "outputId": "1289a7ca-4c0d-4012-cebe-9425ec383ae4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "250"
            ]
          },
          "metadata": {},
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "len(question)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WV-0jxQjJoYd",
        "outputId": "d5bc81fa-1f29-4cb9-a053-362c68260c3a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "625"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain import PromptTemplate\n",
        "\n",
        "# Chinese-language instruction, roughly: \"You are fluent in many languages and a\n",
        "# professional translator. You are responsible for translating {src_lang} into\n",
        "# {dst_lang}.\" The two placeholders are filled in below.\n",
        "template = \"\"\"\n",
        "你精通多种语言，是专业的翻译官。你负责{src_lang}到{dst_lang}的翻译工作。\n",
        "\"\"\"\n",
        "\n",
        "# Only the template text is passed; no explicit variable list is given.\n",
        "prompt = PromptTemplate.from_template(template)\n",
        "\n",
        "# Render with src_lang=\"English\" and dst_lang=\"Chinese\" (written in Chinese);\n",
        "# the formatted string is the cell output.\n",
        "prompt.format(\n",
        "    src_lang=\"英文\",\n",
        "    dst_lang=\"中文\",\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        },
        "id": "iyfud0nDgFLm",
        "outputId": "2a7c8dce-5bac-47e1-9ae2-cd0084951246"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\n你精通多种语言，是专业的翻译官。你负责英文到中文的翻译工作。\\n'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.prompts import (\n",
        "    ChatPromptTemplate,\n",
        "    PromptTemplate,\n",
        "    SystemMessagePromptTemplate,\n",
        "    AIMessagePromptTemplate,\n",
        "    HumanMessagePromptTemplate,\n",
        ")\n",
        "from langchain.schema import (\n",
        "    AIMessage,\n",
        "    HumanMessage,\n",
        "    SystemMessage\n",
        ")"
      ],
      "metadata": {
        "id": "do0XSG1Xo01P"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# System prompt with placeholders for the input and output languages.\n",
        "template = \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
        "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
        "\n",
        "# Fill in both placeholders; the resulting system message is the cell output.\n",
        "system_message_prompt.format(\n",
        "    input_language=\"English\",\n",
        "    output_language=\"Japanese\",\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "N34Z4R7po3Qd",
        "outputId": "cc7c54f1-597b-4909-95ce-6f4084c05618"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "SystemMessage(content='You are a helpful assistant that translates English to Japanese.', additional_kwargs={})"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# System message: translator persona with source/target language placeholders.\n",
        "system_template = \"You are a professional translator that translates {src_lang} to {dst_lang}.\"\n",
        "system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)\n",
        "\n",
        "# Human message: the template is just the `user_input` placeholder.\n",
        "human_template = \"{user_input}\"\n",
        "human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)\n",
        "\n",
        "# Combine both message templates, system first, into a single chat prompt.\n",
        "message_prompts = [system_message_prompt, human_message_prompt]\n",
        "chat_prompt = ChatPromptTemplate.from_messages(message_prompts)\n",
        "\n",
        "# Fill in every placeholder, then display the resulting list of messages.\n",
        "prompt_value = chat_prompt.format_prompt(\n",
        "    src_lang=\"English\",\n",
        "    dst_lang=\"Chinese\",\n",
        "    user_input=\"Did you eat in this morning?\",\n",
        ")\n",
        "prompt_value.to_messages()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "geN12guuqcoj",
        "outputId": "426ea0d8-4584-47f5-938f-caf454bb32c6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[SystemMessage(content='You are a professional translator that translates English to Chinese.', additional_kwargs={}),\n",
              " HumanMessage(content='Did you eat in this morning?', additional_kwargs={}, example=False)]"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.prompts import FewShotPromptTemplate, PromptTemplate\n",
        "from langchain.prompts.example_selector import LengthBasedExampleSelector\n",
        "\n",
        "# (word, antonym) pairs that become the few-shot examples.\n",
        "antonym_pairs = [\n",
        "    (\"happy\", \"sad\"),\n",
        "    (\"tall\", \"short\"),\n",
        "    (\"energetic\", \"lethargic\"),\n",
        "    (\"sunny\", \"gloomy\"),\n",
        "    (\"windy\", \"calm\"),\n",
        "]\n",
        "examples = [{\"input\": word, \"output\": antonym} for word, antonym in antonym_pairs]\n",
        "\n",
        "# How a single example is rendered inside the final prompt.\n",
        "example_prompt = PromptTemplate(\n",
        "    template=\"Input: {input}\\nOutput: {output}\",\n",
        "    input_variables=[\"input\", \"output\"],\n",
        ")\n",
        "\n",
        "example_selector = LengthBasedExampleSelector(\n",
        "    # Candidate examples to select from.\n",
        "    examples=examples,\n",
        "    # Prompt template used to format each example.\n",
        "    example_prompt=example_prompt,\n",
        "    # Maximum length of the formatted examples, as measured by get_text_length.\n",
        "    max_length=25,\n",
        "    # get_text_length: ...\n",
        ")\n",
        "\n",
        "# Few-shot prompt: prefix, the selected examples, then the new input as suffix.\n",
        "dynamic_prompt = FewShotPromptTemplate(\n",
        "    prefix=\"Give the antonym of every input\",\n",
        "    example_selector=example_selector,\n",
        "    example_prompt=example_prompt,\n",
        "    suffix=\"Input: {adjective}\\nOutput:\",\n",
        "    input_variables=[\"adjective\"],\n",
        ")\n",
        "\n",
        "# Render the prompt for a new adjective; the string is the cell output.\n",
        "dynamic_prompt.format(adjective=\"big\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 73
        },
        "id": "GSnhADzDwKvP",
        "outputId": "bacd7d10-4147-4067-f61a-3bb814c40ba4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Give the antonym of every input\\n\\nInput: happy\\nOutput: sad\\n\\nInput: tall\\nOutput: short\\n\\nInput: energetic\\nOutput: lethargic\\n\\nInput: sunny\\nOutput: gloomy\\n\\nInput: windy\\nOutput: calm\\n\\nInput: big\\nOutput:'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 19
        }
      ]
    }
  ]
}